
[https://nvbugs/6029882][fix] Clamp tokens_info writes in computeSeqAndPaddingOffsets#13544

Open
bobboli wants to merge 3 commits into NVIDIA:main from bobboli:fix/nvbug-6029882-tokens-info-bounds

Conversation

Collaborator

@bobboli bobboli commented Apr 28, 2026

Summary

  • Bound each tokensInfo write in computeSeqAndPaddingOffsets to the num_tokens-sized allocation, so the last block cannot scribble past the buffer when the packed seqQLengths sum exceeds numTokens under multi-token generation paths (see the sketch after this list).
  • The std::max(numTokens, seqEnd) extension is preserved for the legitimate sum < numTokens case (filling trailing pad slots); the per-write bound now keeps sum > numTokens from causing OOB.
  • Unwaive accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp].
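
A minimal sketch of the shape of the change described above (not the exact diff; lastBlock, seqBegin, seqEnd, and the loop structure are assumed names, and a later commit in the review thread replaces the silent clamp with printf + __trap):

    int loopEnd = lastBlock ? max(params.numTokens, seqEnd) : seqEnd; // std::max extension preserved
    for (int idx = seqBegin + threadIdx.x; idx < loopEnd; idx += blockDim.x)
    {
        if (idx < params.numTokens) // new: clamp to the [numTokens]-sized tokensInfo allocation
        {
            params.tokensInfo[idx] = make_int2(batchIdx, idx - seqBegin);
        }
    }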

Why

NVBug 6029882. Primary fault is a CUDA Warp MMU exception in computeSeqAndPaddingOffsets<__nv_bfloat16, 256>, with secondary moeA2ADispatchKernel SIGTRAPs on peer ranks waiting on dispatch metadata that never arrived. Core-dump triage attributed the fault to the last-block std::max(numTokens, seqEnd) branch flagged by the original // FIXME(Eagle) comment.
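
For intuition, a hypothetical MTP accounting mismatch of the kind described (numbers invented for illustration; the num_requests sizing branch is discussed later in the thread):

    // Hypothetical: 4 requests, each with runtime_draft_len = 3, so each
    // slot reports seqQLengths[i] = draft_len + 1 = 4.
    //   sum(seqQLengths) = 4 * 4 = 16
    // If the caller sizes buffers with numTokens = num_requests = 4,
    // the kernel sees seqEnd = 16 > numTokens = 4, and the unbounded
    // last-block writes run 12 slots past the allocation.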

Test plan

  • Built RelWithDebInfo with -lineinfo locally.
  • Ran the unwaived test on 8x B200 with CUDA_ENABLE_COREDUMP_ON_EXCEPTION=1. Pass: MMLU 86.96 / GSM8K 95.41 / CnnDailymail 27.89, no core dumps.
  • CI reruns covering GB200-8_GPUs-2_Nodes-PyTorch-2 (the original failing stage).

Summary by CodeRabbit

  • Bug Fixes

    • Enhanced memory safety in tensor computation kernels with improved bounds checking to prevent out-of-range memory access during token processing operations.
    • Corrected token indexing logic to ensure all computations remain within valid ranges throughout execution.
  • Tests

    • Re-enabled the previously skipped multi-GPU throughput test case, restoring coverage for the affected path.

Contributor

coderabbitai Bot commented Apr 28, 2026

📝 Walkthrough

Two independent changes are made: a bounds-checking fix in the CUDA kernel's tokensInfo population logic to prevent out-of-range writes, and a test waive list update that removes a skip directive for a specific test case.

Changes

  • CUDA Kernel Bounds Fix (cpp/tensorrt_llm/kernels/gptKernels.cu): Added a per-element bounds check to the tokensInfo write operations using an idx < params.numTokens condition. Prevents out-of-range writes and aligns downstream reads to the valid token range. Comments updated to reflect the safer behavior; copyright year range updated.
  • Test Waive List Update (tests/integration/test_lists/waives.txt): Removed the SKIP directive for TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp], indicating the associated issue is resolved.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~8 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
  • Title check ✅ Passed: The title clearly and concisely identifies the main fix: clamping tokensInfo writes in computeSeqAndPaddingOffsets to prevent buffer overflow, matching the core change in gptKernels.cu and the NVBugs reference.
  • Description check ✅ Passed: The description comprehensively explains the issue (buffer overflow from seqQLengths exceeding numTokens), the solution (per-write bounds check), the root cause (NVBug 6029882), and test results with metrics; however, the PR Checklist section is incomplete, with most items unchecked/unaddressed.
  • Linked Issues check ✅ Passed: Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check ✅ Passed: Check skipped because no linked issues were found for this pull request.



Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@cpp/tensorrt_llm/kernels/gptKernels.cu`:
- Around line 242-244: The compute_padding_offset lambda writes to
paddingOffsets and encoderPaddingOffsets without guarding against seqBegin +
tokenIdx exceeding the allocated numTokens, which can cause OOB in MTP paths;
update the lambda (compute_padding_offset) to check the destination index
against params.numTokens (or the kernel's numTokens parameter) before writing,
e.g. only assign to paddingOffsets[seqBegin + tokenIdx] and
encoderPaddingOffsets[...] when (seqBegin + tokenIdx) < params.numTokens,
mirroring the existing tokensInfo check to keep kernel-safe bounds.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: b5b78eb0-a59d-4cb3-8ca5-02289a69273d

📥 Commits

Reviewing files that changed from the base of the PR and between 0b9dfdc and cd62807.

📒 Files selected for processing (2)
  • cpp/tensorrt_llm/kernels/gptKernels.cu
  • tests/integration/test_lists/waives.txt
💤 Files with no reviewable changes (1)
  • tests/integration/test_lists/waives.txt

Comment thread cpp/tensorrt_llm/kernels/gptKernels.cu Outdated
Comment on lines +242 to +244
// Iterate over the tokens. Each write is bounded by the tokensInfo allocation
// (sized to numTokens on the caller side); the bound keeps the kernel safe even if
// the caller's seqQLengths sum exceeds numTokens under multi-token generation paths.
Contributor


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Verify how BuildDecoderInfoParams is populated at call sites.
# Expected: in any path where numTokens can be < sum(seqQLengths),
# paddingOffsets/encoderPaddingOffsets should be null, or code should clamp writes similarly.

rg -n --type=cpp -C4 \
'BuildDecoderInfoParams<|\.numTokens\s*=|\.seqQLengths\s*=|\.paddingOffsets\s*=|\.encoderPaddingOffsets\s*=|invokeBuildDecoderInfo\s*\('

# Inspect MTP / multi-token related paths for those assignments.
rg -n --type=cpp -C4 \
'multi[_ ]?token|multi_query_tokens|mtp|throughput_mtp|isBuildDecoderInfoKernelNeeded'

Repository: NVIDIA/TensorRT-LLM

Length of output: 50376


🏁 Script executed:

# Find the compute_padding_offset kernel implementation and nearby bounds checking
rg -n --type=cuda 'compute_padding_offset|paddingOffsets\[' cpp/tensorrt_llm/kernels/gptKernels.cu -A 5 -B 5 | head -150

Repository: NVIDIA/TensorRT-LLM

Length of output: 92


🏁 Script executed:

# Find the compute_padding_offset kernel implementation
rg -n 'compute_padding_offset|paddingOffsets\[' cpp/tensorrt_llm/kernels/gptKernels.cu -A 5 -B 5 | head -200

Repository: NVIDIA/TensorRT-LLM

Length of output: 1368


🏁 Script executed:

# Now check where paddingOffsets/encoderPaddingOffsets are allocated and their expected sizes
# Look at the call sites more carefully to see buffer allocation
rg -n 'padding_offset|paddingOffsets' cpp/tensorrt_llm/thop/trtllmGenQKVProcessOp.cpp -B 3 -A 3

Repository: NVIDIA/TensorRT-LLM

Length of output: 1645


🏁 Script executed:

# Let's look at the tokensInfo write and the bounds guard mentioned in the review
rg -n 'tokensInfo' cpp/tensorrt_llm/kernels/gptKernels.cu -A 10 -B 5 | head -100

Repository: NVIDIA/TensorRT-LLM

Length of output: 2087


🏁 Script executed:

# Let's check the exact bounds issue: compare line 212 (paddingOffsets write) with line 248 (tokensInfo write)
sed -n '190,260p' cpp/tensorrt_llm/kernels/gptKernels.cu

Repository: NVIDIA/TensorRT-LLM

Length of output: 3138


🏁 Script executed:

# Check buffer size allocations - look for how padding_offset tensor is sized
rg -n 'padding_offset.*size\|num_tokens' cpp/tensorrt_llm/thop/trtllmGenQKVProcessOp.cpp -B 2 -A 2

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Search for where padding_offset tensor is created/sized in torch ops
rg -n 'padding.*offset' cpp/tensorrt_llm/thop/trtllmGenQKVProcessOp.cpp

Repository: NVIDIA/TensorRT-LLM

Length of output: 498


🏁 Script executed:

# Find Python/callers that allocate padding_offset tensors
find cpp/tensorrt_llm -name "*.py" -type f | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 248


🏁 Script executed:

# Search for calls to build_decoder_info in Python to understand buffer allocation
rg -n 'build_decoder_info\|padding_offset' --type=py -B 3 -A 3 | head -100

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Search for buffer allocations/sizing related to padding_offset in all C++ files
rg -n 'padding.*offset.*size|numTokens' cpp/tensorrt_llm/common/attentionOp.cpp -B 3 -A 3 | head -150

Repository: NVIDIA/TensorRT-LLM

Length of output: 5155


🏁 Script executed:

# Verify the padding_offset sizing in context - compare all three buffers side by side
sed -n '1446,1493p' cpp/tensorrt_llm/common/attentionOp.cpp

Repository: NVIDIA/TensorRT-LLM

Length of output: 3463


🏁 Script executed:

# Check if there's any bounds guard or null check for paddingOffsets in MTP paths
rg -n 'multi_query_tokens\|removePadding.*true' cpp/tensorrt_llm/kernels/xqaDispatcher.cpp -B 5 -A 5

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Check how paddingOffsets is set in xqaDispatcher when multi_query_tokens is true
sed -n '350,400p' cpp/tensorrt_llm/kernels/xqaDispatcher.cpp

Repository: NVIDIA/TensorRT-LLM

Length of output: 3130


🏁 Script executed:

# Check more carefully: what is the relationship between batch_size*input_seq_length and numTokens?
rg -n 'batch_size.*input_seq_length\|num_tokens\|removePadding' cpp/tensorrt_llm/common/attentionOp.cpp | grep -E '(batch_size|input_seq_length|num_tokens)' | head -30

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Check how paddingOffsets is actually used and what bounds apply
# Look at all places where paddingOffsets is set or used
rg -n 'paddingOffsets' cpp/tensorrt_llm/kernels/gptKernels.cu -B 2 -A 8

Repository: NVIDIA/TensorRT-LLM

Length of output: 1953


🏁 Script executed:

# Let's verify the buffer sizing issue more clearly - when paddingOffsets is allocated
# Look at context path vs generation path
grep -n 'padding_offset_size\|input_seq_length\|num_tokens' cpp/tensorrt_llm/common/attentionOp.cpp | head -40

Repository: NVIDIA/TensorRT-LLM

Length of output: 3755


Add bounds checking to the compute_padding_offset lambda to prevent OOB writes when sum(seqQLengths) > numTokens in MTP paths.

The tokensInfo buffer is protected by an explicit if (idx < params.numTokens) check (line 248), but paddingOffsets and encoderPaddingOffsets are written directly at [seqBegin + tokenIdx] (line 212) without similar bounds validation. In multi-token generation scenarios where the sum of seqQLengths exceeds numTokens, these writes can run past the allocated buffers. Add a bounds check mirroring the tokensInfo pattern:

if (seqBegin + tokenIdx < numTokens)
{
    paddingOffsets[seqBegin + tokenIdx] = paddingOffset;
}

Note: The bounds limit should come from buffer allocation parameters passed to the kernel (e.g., params.numTokens or equivalent).
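
A fuller sketch of the guarded lambda, as a hypothetical reconstruction from the review text and the follow-up commit description (the real compute_padding_offset signature may differ):

    // Sketch: the lambda takes the allocation capacity explicitly and
    // skips any write whose destination index would exceed it.
    auto compute_padding_offset
        = [](int* offsets, int seqBegin, int seqLen, int paddingOffset, int numTokens)
    {
        if (offsets == nullptr) // some callers pass nullptr on this path
        {
            return;
        }
        for (int tokenIdx = threadIdx.x; tokenIdx < seqLen; tokenIdx += blockDim.x)
        {
            if (seqBegin + tokenIdx < numTokens) // mirror the tokensInfo bound
            {
                offsets[seqBegin + tokenIdx] = paddingOffset;
            }
        }
    };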


Comment thread cpp/tensorrt_llm/kernels/gptKernels.cu Outdated
Comment on lines +248 to +251
if (idx < params.numTokens)
{
params.tokensInfo[idx] = make_int2(batchIdx, tokenIdx);
}
Collaborator


Will this bring silent errors, since nothing handles the else case?
In other words, the root cause of idx running past numTokens is not fixed.

Collaborator Author


Fair point — agreed the silent skip traded a hard fault for a worse failure mode (downstream attention reads stale (batchIdx, tokenIdxInSeq) and produces wrong outputs without crashing). Pushed 878fdc21eb which converts the bound check from a silent skip to printf(...) + __trap() for both tokensInfo and paddingOffsets/encoderPaddingOffsets. The kernel still cannot scribble past the [numTokens]-sized allocation, but if the upstream sum(seqQLengths) > numTokens invariant ever breaks the run aborts with (blockIdx, batchSize, capacity, seqBegin, seqEnd) printed instead of producing silent garbage.
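
A sketch of the trap-on-violation pattern described here (the message wording and exact variable names are illustrative, not the literal commit):

    // On a bound violation, print an actionable signature and abort the
    // kernel instead of silently skipping the write.
    if (idx >= params.numTokens)
    {
        printf("computeSeqAndPaddingOffsets: tokensInfo write out of bounds: "
               "blockIdx=%d batchSize=%d capacity=%d seqBegin=%d seqEnd=%d\n",
               blockIdx.x, params.batchSize, params.numTokens, seqBegin, seqEnd);
        __trap(); // hard abort with context instead of silent garbage
    }
    params.tokensInfo[idx] = make_int2(batchIdx, tokenIdx);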

I traced both call paths on Blackwell looking for the actual root cause — for the trtllm-gen Python path is_spec_decoding_enabled is forced to False at trtllm.py:1537 so spec_decoding_generation_lengths is never wired in and the kernel takes the safe fixed_q_seqlen branch. For the legacy XQA path the same bool gate at attentionOp.cpp:491 keeps the buffer null and multi_query_tokens=False. The one suspicious line I found is attentionOp.cpp:293 (total_num_input_tokens = mCpSize > 1 ? num_requests : num_tokens) which would set numTokens=num_requests while seqQLengths report runtime_draft_len+1 per slot, but I couldn't find an assignment that flips mCpSize > 1 from the PyTorch path. Couldn't reproduce the original Warp MMU fault on single-node 8×B200 or 8×B300 across multiple iterations either, so root cause remains open.

So treating this PR as the memory-safety mitigation only, leaving NVBug 6029882 open for the upstream metadata investigation. The trap+printf at least guarantees that if it recurs in CI we get an actionable signature instead of an opaque MMU fault.

bobboli added 3 commits April 29, 2026 03:47
…ndPaddingOffsets

Bound each tokensInfo write to the num_tokens-sized allocation so the
last block cannot scribble past the buffer when the packed seqQLengths
sum exceeds numTokens. The previous code relied on std::max(numTokens,
seqEnd) on the last block alone, which left out-of-bounds writes possible
in multi-token generation paths where the caller and the kernel disagree
on token accounting.

Observed manifestation: primary CUDA Warp MMU fault in the kernel with
secondary moeA2ADispatchKernel SIGTRAPs on peer ranks waiting for
dispatch metadata that never arrived -- stage
GB200-8_GPUs-2_Nodes-PyTorch-2 test
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp]. Core-dump triage attributed the fault to
the last-block std::max branch in gptKernels.cu.

Also unwaive the test.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Mirror the tokensInfo bound (added in cd62807) on the
compute_padding_offset lambda. paddingOffsets and encoderPaddingOffsets
are also [numTokens]-sized per BuildDecoderInfoParams; the lambda now
takes the capacity explicitly and skips writes past it. Current callers
pass paddingOffsets=nullptr on the multi-token gen path that triggered
NVBug 6029882, but the kernel should not depend on that for safety.

Note: not validated locally because the dev container is currently
missing cmake / TensorRT / libnuma-dev that the build setup expects;
relying on CI to validate this mechanical follow-up.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
…ts OOB

Per review feedback: the silent per-write clamp added in the previous
commits prevents memory corruption but trades a hard CUDA fault for
silent dropped writes, which is worse for inference correctness because
downstream attention reads stale (batchIdx, tokenIdxInSeq) entries.

Convert each idx-bound check from a silent skip into a one-line printf
diagnostic plus __trap(). The kernel still cannot scribble past the
[numTokens]-sized allocation, but if the upstream metadata bug ever
surfaces (sum(seqQLengths) > numTokens) it now aborts the run with
actionable context (blockIdx, batchSize, capacity, seqBegin, seqEnd)
rather than producing wrong outputs.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
@bobboli bobboli force-pushed the fix/nvbug-6029882-tokens-info-bounds branch from a6e035e to 878fdc2 on April 29, 2026 at 06:34